library("dslabs")
library("ggplot2")
library("knitr")
library("tidymodels")
library("tidyr")
library("dplyr")
theme_set(theme_minimal())
The goal of clustering is to discover distinct groups within a dataset. In an ideal clustering, samples are very different between groups, but relatively similar within groups. At the end of a clustering routine, \(K\) clusters have been identified, and each sample is assigned to one of these \(K\) clusters. \(K\) must be chosen by the user.
Clustering gives a compressed representation of the dataset. Therefore, clustering is useful for getting a quick overview of the high-level structure in a dataset.
For example, clustering can be used in applications ranging from customer segmentation to grouping genes with similar expression profiles.
The \(K\)-means algorithm alternates between two steps: assign each sample to its nearest centroid, then recompute each centroid as the average of the samples assigned to it. The tidymodels page on \(K\)-means includes an animation of this process.
Note that, since we have to take an average for each coordinate, we require that our data be quantitative, not categorical.
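As a small illustration of this point (toy simulated data, not the movie ratings used below), we can check that the centers returned by kmeans are exactly the coordinate-wise averages of the samples assigned to each cluster:

```r
# Toy data: two well-separated Gaussian blobs in two quantitative columns.
set.seed(1)
x <- rbind(
  matrix(rnorm(40, mean = 0), ncol = 2),
  matrix(rnorm(40, mean = 5), ncol = 2)
)
fit <- kmeans(x, centers = 2)

# Recompute each center by hand as the per-cluster column average.
manual_centers <- apply(x, 2, function(col) tapply(col, fit$cluster, mean))
all.equal(unname(fit$centers), unname(manual_centers)) # TRUE
```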
We illustrate this idea using the movielens dataset from the reading. This dataset has ratings (0.5 to 5) given by 671 users across 9066 movies. We can think of this as a matrix of movies vs. users, with ratings within the entries. For simplicity, we filter down to only the 50 most frequently rated movies. We will assume that if a user never rated a movie, they would have given that movie a zero. We’ve skipped a few steps used in the reading (subtracting movie / user averages and filtering to only active users), but the overall results are comparable.
data("movielens")
frequently_rated <- movielens %>%
  group_by(movieId) %>%
  summarize(n = n()) %>%
  top_n(50, n) %>%
  pull(movieId)

movie_mat <- movielens %>%
  filter(movieId %in% frequently_rated) %>%
  select(title, userId, rating) %>%
  pivot_wider(id_cols = title, names_from = userId, values_from = rating, values_fill = 0)
movie_mat[1:10, 1:20]
## # A tibble: 10 x 20
## title `2` `3` `4` `5` `6` `7` `8` `9` `10` `11` `12` `13`
## <chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 Seve~ 4 0 0 0 0 0 5 3 0 0 0 2.5
## 2 Usua~ 4 0 0 0 0 0 5 0 5 5 0 0
## 3 Brav~ 4 4 0 0 0 5 4 0 0 0 0 4
## 4 Apol~ 5 0 0 4 0 0 0 0 0 0 0 0
## 5 Pulp~ 4 4.5 5 0 0 0 4 0 0 5 0 3.5
## 6 Forr~ 3 5 5 4 0 3 4 0 0 0 0 5
## 7 Lion~ 3 0 5 4 0 3 0 0 0 0 0 0
## 8 Mask~ 3 0 4 4 0 3 0 0 0 0 0 0
## 9 Speed 3 2.5 0 4 0 3 0 0 0 0 0 0
## 10 Fugi~ 3 0 0 0 0 0 4.5 0 0 0 0 0
## # ... with 7 more variables: 14 <dbl>, 15 <dbl>, 16 <dbl>, 17 <dbl>, 18 <dbl>,
## # 19 <dbl>, 20 <dbl>
We can now run kmeans on this dataset. I’ve used the dplyr pipe notation to run kmeans on the data above with “title” removed. augment is a function from the tidymodels package that adds the cluster labels identified by kmeans to the rows in the original dataset.

kclust <- movie_mat %>%
  select(-title) %>%
  kmeans(centers = 10)

movie_mat <- augment(kclust, movie_mat) # creates column ".cluster" with cluster label
kclust <- tidy(kclust)

movie_mat %>%
  select(title, .cluster) %>%
  arrange(.cluster)
## # A tibble: 50 x 2
## title .cluster
## <chr> <fct>
## 1 Forrest Gump 1
## 2 Schindler's List 1
## 3 Silence of the Lambs, The 1
## 4 Braveheart 2
## 5 Apollo 13 2
## 6 Speed 2
## 7 Fugitive, The 2
## 8 Jurassic Park 2
## 9 Terminator 2: Judgment Day 2
## 10 Dances with Wolves 2
## # ... with 40 more rows
kclust_long <- kclust %>%
  pivot_longer(`2`:`671`, names_to = "userId", values_to = "rating")

ggplot(kclust_long) +
  geom_bar(
    aes(x = reorder(userId, rating), y = rating),
    stat = "identity"
  ) +
  facet_grid(cluster ~ .) +
  labs(x = "Users (sorted)", y = "Rating") +
  theme(
    axis.text.x = element_blank(),
    axis.text.y = element_text(size = 5),
    strip.text.y = element_text(angle = 0)
  )
We can visualize each cluster by seeing the average ratings each user gave to the movies in that cluster (this is the definition of the centroid). An alternative visualization strategy would be to show a heatmap – we’ll discuss this soon in the superheat lecture.
It’s often of interest to relate the cluster assignments to complementary data, to see whether the clustering reflects any previously known differences between the observations, which weren’t directly used in the clustering algorithm.
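For instance (a hypothetical sketch; the genre column here is invented for illustration), we might cross-tabulate the cluster assignments against a complementary label that wasn’t used during clustering:

```r
# Suppose each movie has an externally known genre, not used by kmeans.
set.seed(2)
movies <- data.frame(
  title = paste("Movie", 1:12),
  .cluster = factor(rep(1:3, each = 4)), # assignments from a clustering
  genre = sample(c("drama", "scifi"), 12, replace = TRUE)
)

# If the clustering reflects genre, counts will concentrate within rows.
table(movies$.cluster, movies$genre)
```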
Be cautious: Outliers, nonspherical shapes, and variations in density can throw off \(K\)-means.
The difficulty that variations in density pose to \(K\)-means, from Cluster Analysis using K-Means Explained.
library("dplyr")
library("ggplot2")
library("ggraph")
library("knitr")
library("readr")
library("robservable")
library("tidygraph")
theme_set(theme_graph())
In reality, data are rarely separated into a clear number of homogeneous clusters. More often, even once a cluster has formed, it’s possible to identify a few subclusters. For example, if you initially clustered movies into “drama” and “scifi”, you might be able to further refine the scifi cluster into “time travel” and “aliens.”
\(K\)-means only allows clustering at a single level of magnification. To instead simultaneously cluster across scales, you can use an approach called hierarchical clustering. As a first observation, note that a tree can be used to implicitly store many clusterings at once. You can get a standard clustering by cutting the tree at some level.
We can recover clusters at different levels of granularity by cutting a hierarchical clustering tree.
robservable("@mbostock/tree-of-life", height = 1150)
Elaborating on this analogy, the leaves of a hierarchical clustering tree are the original observations. The more recently two nodes share a common ancestor, the more similar those observations are.
The specific algorithm proceeds as follows,
At initialization, the hierarchical clustering routine has a cluster for each observation.
Next, the two closest observations are merged into one cluster. This is the first merge point on the tree.
We continue this at the next iteration, though this time we have to compute the pairwise distances between all clusters, not observations (technically, every observation was its own cluster at the first step, so in both cases we compare the pairwise distances between clusters).
We can continue this process…
… and eventually we will construct the entire tree.
In R, this can be accomplished by using the hclust function. First, we compute the distances between all pairs of observations (this provides the dissimilarities used in the algorithm). Then, we apply hclust to the matrix of pairwise distances.
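To see the merge sequence concretely, here is a minimal sketch on toy one-dimensional data (invented for illustration). The merge matrix records which clusters were joined at each step, and height records the distance at which each merge occurred:

```r
# Five points on a line: {0, 0.1} are closest, then {5, 5.2}.
x <- c(0, 0.1, 1, 5, 5.2)
hc <- hclust(dist(x))

# Negative entries refer to original observations, positive entries to
# clusters formed at earlier rows. The first merge joins observations 1 and 2.
hc$merge
hc$height # non-decreasing: later merges happen at larger distances
```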
We apply this to a movie ratings dataset. Movies are considered similar if they tend to receive similar ratings across all audience members. The result is visualized below.
movies_mat <- read_csv("https://uwmadison.box.com/shared/static/wj1ln9xtigaoubbxow86y2gqmqcsu2jk.csv")
##
## -- Column specification --------------------------------------------------------
## cols(
## .default = col_double(),
## title = col_character()
## )
## i Use `spec()` for the full column specifications.
D <- movies_mat %>%
  column_to_rownames(var = "title") %>%
  dist()
hclust_result <- hclust(D)
plot(hclust_result, cex = 0.5)
We can convert the hclust output into a tidygraph object using the as_tbl_graph function from the network and trees lectures.

hclust_graph <- as_tbl_graph(hclust_result, height = height)
hclust_graph <- hclust_graph %>%
  mutate(height = ifelse(height == 0, 27, height)) # shorten the final edge
hclust_graph
## # A tbl_graph: 99 nodes and 98 edges
## #
## # A rooted tree
## #
## # Node Data: 99 x 4 (active)
## height leaf label members
## <dbl> <lgl> <chr> <int>
## 1 27 TRUE "Schindler's List" 1
## 2 27 TRUE "Forrest Gump" 1
## 3 27 TRUE "Shawshank Redemption, The" 1
## 4 27 TRUE "Pulp Fiction" 1
## 5 27 TRUE "Silence of the Lambs, The" 1
## 6 58.7 FALSE "" 2
## # ... with 93 more rows
## #
## # Edge Data: 98 x 2
## from to
## <int> <int>
## 1 6 4
## 2 6 5
## 3 7 3
## # ... with 95 more rows
ggraph(hclust_graph, "dendrogram", height = height, circular = TRUE) +
  geom_edge_elbow() +
  geom_node_text(aes(label = label), size = 4) +
  coord_fixed()
cluster_df <- cutree(hclust_result, k = 10) %>% # try changing k and regenerating the graph below
  tibble(label = names(.), cluster = as.factor(.))
cluster_df
## # A tibble: 50 x 3
## . label cluster
## <int> <chr> <fct>
## 1 1 Seven (a.k.a. Se7en) 1
## 2 1 Usual Suspects, The 1
## 3 2 Braveheart 2
## 4 2 Apollo 13 2
## 5 3 Pulp Fiction 3
## 6 4 Forrest Gump 4
## 7 2 Lion King, The 2
## 8 2 Mask, The 2
## 9 2 Speed 2
## 10 2 Fugitive, The 2
## # ... with 40 more rows
# colors chosen using https://medialab.github.io/iwanthue/
cols <- c("#51b48c", "#cf3d6e", "#7ab743", "#7b62cb", "#c49644", "#c364b9", "#6a803a", "#688dcd", "#c95a38", "#c26b7e")
hclust_graph %>%
  left_join(cluster_df) %>%
  ggraph("dendrogram", height = height, circular = TRUE) +
  geom_edge_elbow() +
  geom_node_text(aes(label = label, col = cluster), size = 4) +
  coord_fixed() +
  scale_color_manual(values = cols) +
  theme(legend.position = "none")
## Joining, by = "label"
library("dplyr")
library("ggplot2")
library("readr")
library("superheat")
library("tibble")
theme_set(theme_minimal())
The direct outputs of a standard clustering algorithm are (a) the cluster assignments for each sample and (b) the centroids associated with each cluster. A hierarchical clustering algorithm enriches this output with a tree, which provides (a) and (b) at multiple levels of resolution.
These outputs can be used to improve visualizations. For example, they can be used to define small multiples, faceting across clusters. One especially common idea is to reorder the rows of a heatmap using the results of a clustering, and this is the subject of these notes.
In a heatmap, each mark (usually a small tile) corresponds to an entry of a matrix. The \(x\)-coordinate of the mark encodes the index of the observation, while the \(y\)-coordinate encodes the index of the feature. The color of each tile represents the value of that entry. For example, here are the first few rows of the movies data, along with the corresponding heatmap, made using the superheat package.
movies_mat <- read_csv("https://uwmadison.box.com/shared/static/wj1ln9xtigaoubbxow86y2gqmqcsu2jk.csv") %>%
  column_to_rownames(var = "title")
##
## -- Column specification --------------------------------------------------------
## cols(
## .default = col_double(),
## title = col_character()
## )
## i Use `spec()` for the full column specifications.
cols <- c('#f6eff7','#bdc9e1','#67a9cf','#1c9099','#016c59')
superheat(movies_mat, left.label.text.size = 4, heat.pal = cols, heat.lim = c(0, 5))
## Warning in regularize.values(x, y, ties, missing(ties), na.rm = na.rm):
## collapsing to unique 'x' values
movies_clust <- movies_mat %>%
  kmeans(centers = 10)

users_clust <- movies_mat %>%
  t() %>%
  kmeans(centers = 10)

superheat(
  movies_mat,
  left.label.text.size = 4,
  order.rows = order(movies_clust$cluster),
  order.cols = order(users_clust$cluster),
  heat.pal = cols,
  heat.lim = c(0, 5)
)
superheat can display a plot adjacent to the heatmap. Here, we use the yr argument to encode the total number of ratings given to each movie. The yr.obs.col argument allows us to change the color of each point in the adjacent plot. In this example, we change color depending on which cluster the movie was found to belong to.

cluster_cols <- c('#8dd3c7','#ccebc5','#bebada','#fb8072','#80b1d3','#fdb462','#b3de69','#fccde5','#d9d9d9','#bc80bd')
superheat(
  movies_mat,
  left.label.text.size = 4,
  order.rows = order(movies_clust$cluster),
  order.cols = order(users_clust$cluster),
  heat.pal = cols,
  heat.lim = c(0, 5),
  yr = rowSums(movies_mat > 0),
  yr.axis.name = "Number of Ratings",
  yr.obs.col = cluster_cols[movies_clust$cluster],
  yr.plot.type = "bar"
)
The pretty.order.rows and pretty.order.cols arguments use hierarchical clustering to reorder the heatmap.

superheat(
  movies_mat,
  left.label.text.size = 4,
  pretty.order.cols = TRUE,
  pretty.order.rows = TRUE,
  heat.pal = cols,
  heat.lim = c(0, 5)
)
The dendrograms associated with pretty.order.rows and pretty.order.cols can also be visualized.

superheat(
  movies_mat,
  left.label.text.size = 4,
  pretty.order.cols = TRUE,
  pretty.order.rows = TRUE,
  row.dendrogram = TRUE,
  col.dendrogram = TRUE,
  heat.pal = cols,
  heat.lim = c(0, 5)
)
library("cluster")
library("stringr")
library("dplyr")
library("tidymodels")
library("readr")
library("ggplot2")
theme_set(theme_bw())
set.seed(123)
Clustering algorithms usually require the number of clusters \(K\) as an argument. How should it be chosen?
There are many possible criteria, but one common approach is to compute the silhouette statistic. This is a statistic computed for each observation in a dataset, measuring how strongly the observation is tied to its assigned cluster. If a whole cluster has large silhouette statistics, then that cluster is well-defined and clearly isolated from the other clusters.
The plots below illustrate the computation of silhouette statistics for a clustering of the penguins dataset that used \(K = 3\). To set up, we first need to cluster the penguins dataset. The idea is the same as in the \(K\)-means notes, but we encapsulate the code in a function, so that we can easily extract data for different values of \(K\).
penguins <- read_csv("https://uwmadison.box.com/shared/static/ijh7iipc9ect1jf0z8qa2n3j7dgem1gh.csv") %>%
  na.omit() %>%
  mutate(id = row_number())
##
## -- Column specification --------------------------------------------------------
## cols(
## species = col_character(),
## island = col_character(),
## bill_length_mm = col_double(),
## bill_depth_mm = col_double(),
## flipper_length_mm = col_double(),
## body_mass_g = col_double(),
## sex = col_character(),
## year = col_double()
## )
cluster_penguins <- function(penguins, K) {
  x <- penguins %>%
    select(matches("length|depth|mass")) %>%
    scale()
  kmeans(x, centers = K) %>%
    augment(penguins) %>% # creates column ".cluster" with cluster label
    mutate(silhouette = silhouette(as.integer(.cluster), dist(x))[, "sil_width"])
}
cur_id <- 2
penguins3 <- cluster_penguins(penguins, K = 3)
obs_i <- penguins3 %>%
  filter(id == cur_id)
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
geom_point(data = obs_i, size = 5, col = "black") +
geom_point() +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1))
The observation on which we will compute the silhouette statistic.
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
geom_segment(
data = penguins3 %>% filter(.cluster == obs_i$.cluster),
aes(xend = obs_i$bill_length_mm, yend = obs_i$bill_depth_mm),
size = 0.6, alpha = 0.3
) +
geom_point(data = obs_i, size = 5, col = "black") +
geom_point() +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1)) +
labs(title = expression(paste("Distances used for ", a[i])))
The average distance between the target observation and all others in the same cluster.
ggplot(penguins3, aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster)) +
geom_segment(
data = penguins3 %>% filter(.cluster != obs_i$.cluster),
aes(xend = obs_i$bill_length_mm, yend = obs_i$bill_depth_mm, col = .cluster),
size = 0.5, alpha = 0.3
) +
geom_point(data = obs_i, size = 5, col = "black") +
geom_point() +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1)) +
labs(title = expression(paste("Distances used for ", b[i][1], " and ", b[i][2])))
The average distance between the target observation and all others in different clusters.
The silhouette statistic for observation \(i\) is derived from the relative lengths of the orange vs. green segments. Formally, the silhouette statistic for observation \(i\) is \(s_{i} := \frac{b_{i} - a_{i}}{\max\left(a_{i}, b_{i}\right)}\). This number is close to 1 if the orange segments are much longer than the green segments, close to 0 if the segments are about the same length, and close to -1 if the orange segments are much shorter than the green segments.
The median of these \(s_{i}\) for all observations within cluster \(k\) is a measure of how well-defined cluster \(k\) is overall. The higher this number, the more well-defined the cluster.
Denote the median of the silhouette statistics within cluster \(k\) by \(SS_{k}\). A measure of how good a choice of \(K\) is can be obtained from the median of these medians: \(\text{Quality}(K) := \text{median}_{k = 1, \dots, K} SS_{k}\).
In particular, this can be used to define (a) a good cut point in a hierarchical clustering or (b) a point at which a cluster should no longer be split into subgroups.
In R, we can use the silhouette function from the cluster package to compute the silhouette statistic. The syntax is silhouette(cluster_labels, pairwise_distances) where cluster_labels is a vector of (integer) cluster ID’s for each observation and pairwise_distances gives the lengths of the segments between all pairs of observations. An example of this function’s usage is given in the function at the start of the illustration.
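Putting the pieces together, here is a minimal sketch (on simulated blobs, not the penguins; all names are invented for illustration) that uses the median-of-medians silhouette quality to compare several values of \(K\):

```r
library("cluster")
set.seed(123)

# Three well-separated Gaussian blobs, so K = 3 should score best.
x <- rbind(
  matrix(rnorm(100, mean = 0), ncol = 2),
  matrix(rnorm(100, mean = 4), ncol = 2),
  matrix(rnorm(100, mean = 8), ncol = 2)
)
D <- dist(x)

quality <- sapply(2:6, function(K) {
  fit <- kmeans(x, centers = K, nstart = 10)
  sil <- silhouette(fit$cluster, D)[, "sil_width"]
  median(tapply(sil, fit$cluster, median)) # median SS_k across clusters
})
names(quality) <- 2:6
quality # the maximum should appear near K = 3
```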
This is what the silhouette statistic looks like in the penguins dataset when we choose 3 clusters. The larger points have lower silhouette statistics. The points between clusters 2 and 3 have low silhouette statistics because those two clusters blend into one another.
ggplot(penguins3) +
geom_point(aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster, size = silhouette)) +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1))
The silhouette statistics on the Palmer Penguins dataset, when using \(K\)-means with \(K = 3\).
ggplot(penguins3) +
geom_histogram(aes(x = silhouette), binwidth = 0.05) +
facet_grid(~ .cluster)
The per-cluster histograms of silhouette statistics summarize how well-defined each cluster is.
penguins4 <- cluster_penguins(penguins, K = 4)
ggplot(penguins4) +
geom_point(aes(x = bill_length_mm, y = bill_depth_mm, col = .cluster, size = silhouette)) +
scale_color_brewer(palette = "Set2") +
scale_size(range = c(4, 1))
We can repeat the same exercise, but with \(K = 4\) clusters instead.
ggplot(penguins4) +
geom_histogram(aes(x = silhouette), binwidth = 0.05) +
facet_grid(~ .cluster)
library("MASS")
library("Matrix")
library("dplyr")
library("ggplot2")
library("pdist")
library("superheat")
library("tidyr")
library(knitr)
theme_set(theme_minimal())
set.seed(1234)
One of the fundamental principles in statistics is that, no matter how the experiment / study was conducted, if we ran it again, we would get different results. More formally, sampling variability creates uncertainty in our inferences.
How should we think about sampling variability in the context of clustering? This is a tricky problem, because you can permute the labels of the clusters without changing the meaning of the clustering. However, it is possible to measure and visualize the stability of a point’s cluster assignment.
To make this less abstract, consider an example. A study has found a collection of genes that are differentially expressed between patients with two different subtypes of a disease. There is an interest in clustering genes that have similar expression profiles across all patients — these genes probably belong to similar biological processes.
Once you run the clustering, how sure can you be that, if the study were run again, you would recover a similar clustering? Are there some genes that you are sure belong to a particular cluster? Are there some that lie between two clusters?
To illustrate, consider the simulated dataset below. Imagine that the rows are patients, the column are genes, and the colors are the expression levels of genes within patients. There are 5 clusters of genes here (columns 1 - 20 are cluster 1, 21 - 41 are cluster 2, …). The first two clusters are only weakly visible, while the last three stand out strongly.
n_per <- 20
p <- n_per * 5
Sigma1 <- diag(2) %x% matrix(rep(0.3, n_per ** 2), nrow = n_per)
Sigma2 <- diag(3) %x% matrix(rep(0.6, n_per ** 2), nrow = n_per)
Sigma <- bdiag(Sigma1, Sigma2)
diag(Sigma) <- 1
mu <- rep(0, 100)
x <- mvrnorm(25, mu, Sigma)
cols <- c('#f6eff7','#bdc9e1','#67a9cf','#1c9099','#016c59')
superheat(
  x,
  pretty.order.rows = TRUE,
  bottom.label = "none",
  heat.pal = cols,
  left.label.text.size = 3,
  legend = FALSE
)
A simulated gene expression dataset, with genes in the columns and patients in the rows.
K <- 5
B <- 1000
cluster_profiles <- kmeans(t(x), centers = K)$centers
cluster_probs <- matrix(nrow = ncol(x), ncol = B)
for (b in seq_len(B)) {
  b_ix <- sample(nrow(x), replace = TRUE)
  dists <- as.matrix(pdist(t(x[b_ix, ]), cluster_profiles[, b_ix]))
  cluster_probs[, b] <- apply(dists, 1, which.min)
}

cluster_probs <- as_tibble(cluster_probs) %>%
  mutate(gene = row_number()) %>%
  pivot_longer(-gene, names_to = "b", values_to = "cluster")
## Warning: The `x` argument of `as_tibble.matrix()` must have unique column names if `.name_repair` is omitted as of tibble 2.0.0.
## Using compatibility `.name_repair`.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_warnings()` to see where this warning was generated.
cluster_probs <- cluster_probs %>%
  mutate(cluster = as.factor(cluster)) %>%
  group_by(gene, cluster) %>%
  summarise(prob = n() / B)
## `summarise()` has grouped output by 'gene'. You can override using the `.groups` argument.
cluster_probs
## # A tibble: 267 x 3
## # Groups: gene [100]
## gene cluster prob
## <int> <fct> <dbl>
## 1 1 2 0.956
## 2 1 3 0.041
## 3 1 5 0.003
## 4 2 2 0.778
## 5 2 5 0.222
## 6 3 1 0.001
## 7 3 2 0.978
## 8 3 5 0.021
## 9 4 1 0.001
## 10 4 2 0.689
## # ... with 257 more rows
ggplot(cluster_probs) +
  geom_bar(aes(y = as.factor(gene), x = prob, col = cluster, fill = cluster), stat = "identity") +
  scale_fill_brewer(palette = "Set2") +
  scale_color_brewer(palette = "Set2") +
  scale_x_continuous(expand = c(0, 0)) +
  labs(y = "Gene", x = "Proportion") +
  theme(
    axis.ticks.y = element_blank(),
    axis.text.y = element_text(size = 7),
    legend.position = "bottom"
  )